Skip to content

Add anchor and hyperparameter-init controls to posterior_mean_function#829

Draft
kalama-ai wants to merge 16 commits into
feature/posterior_mean_propertyfrom
feature/posterior_mean_configurations
Draft

Add anchor and hyperparameter-init controls to posterior_mean_function#829
kalama-ai wants to merge 16 commits into
feature/posterior_mean_propertyfrom
feature/posterior_mean_configurations

Conversation

@kalama-ai

Copy link
Copy Markdown
Collaborator

Add anchors and mean_kernel_init controls to posterior_mean_function

What this does

The base branch added GaussianProcessSurrogate.posterior_mean_function, which turns a trained GP's posterior mean into a mean module you can plug into a second GP. That gave us mean transfer, but with exactly one fixed behavior. This PR opens up that behavior with two keyword arguments:

def posterior_mean_function(
    self,
    searchspace,
    objective,
    measurements,
    *,
    anchors: Literal["pretrained", "new", "combined"] = "pretrained",
    mean_kernel_init: Literal["freeze", "warmstart", "discard"] = "freeze",
) -> GPyTorchMean

The defaults (anchors="pretrained", mean_kernel_init="freeze") reproduce the existing behavior exactly, so nothing changes unless you opt in.

The two knobs

anchors decides which data the inner GP is conditioned on when computing the transferred mean:

  • "pretrained" — the source GP's own training data (recovered in raw space). This is the "pure transfer" case.
  • "new" — the new GP's measurements only.
  • "combined" — both, concatenated.

mean_kernel_init decides what happens to the inner GP's mean/kernel/likelihood once transfer is set up:

  • "freeze" — deep-copy the pretrained modules and lock them (requires_grad=False). The inner mean is a fixed, static prior.
  • "warmstart" — deep-copy them but leave them trainable, so the outer MLL can keep adjusting them.
  • "discard" — throw away the pretrained hyperparameters and start the inner modules fresh from the factories, trainable.

Importantly, discard resets all three modules (mean, kernel, and likelihood), not just the kernel, because all three feed into the posterior mean.

Guards and a warning

Not every combination is sound, so the method validates up front instead of letting things fail later:

  • anchors="new" + mean_kernel_init="discard" is rejected with a ValueError — it transfers no pretrained information at all, so it would just be a plain GP refit dressed up as transfer learning.
  • mean_kernel_init="warmstart" with anchors in {"new", "combined"} emits a warning. When the inner anchors include the new targets and the inner mean is free to move, the same y_new ends up driving both the inner prior mean and the outer marginal likelihood. The flexible inner mean can then interpolate y_new, the outer residual collapses, and the MLL drives the outer noise toward zero — overconfident posteriors. freeze is always safe here, since a frozen inner mean can't chase y_new.

AdrianSosic and others added 16 commits June 18, 2026 11:37
- add SearchSpace._drop_parameters and a lightweight _ReducedSearchSpace
- preserve parameter metadata and computational-representation column counting without requiring full candidate data
- make GP kernel and fit-criterion factories use the reduced search space interface via n_tasks and _get_n_comp_rep_columns
- add tests covering reduced search space behavior and blocked unsupported access
Introduces `_ReducedSearchSpace` (private `SearchSpace` subclass) that
exposes only parameter information without building the expensive
discrete Cartesian product. This is constructed via
`SearchSpace._drop_parameters(names)`, which returns a reduced version
with the specified parameters removed.

The class blocks access to anything beyond parameter-related properties.

  **Motivation**

When building composite GP kernels, we need to call kernel factories on
search spaces with certain parameters stripped away. A full
`SearchSpace.from_product()` would rebuild the entire discrete candidate
set. `_ReducedSearchSpace` avoids this by storing the parameter tuple
directly and computing everything from it.

  **Changes to `SearchSpace` and kernel factories**

The reduced search space exposes only name-based properties, no integer
indices. To support this, we made two changes to existing code:

- **Replaced `task_idx` sentinel checks** with `n_tasks > 1` / `n_tasks
== 1` in kernel factories and fit criterion factories. (Equivalent
because `TaskParameter` requires a minimum of 2 values).
- **Added `SearchSpace._get_n_comp_rep_columns(selector)`** that returns
the number of comp-rep columns for a parameter selection. This replaces
the previous `len(get_comp_rep_parameter_indices(...))` in
`_get_effective_dimensionality`, avoiding the need to expose
index-returning methods on the reduced space. Uses
`sum(len(p.comp_rep_columns) for p in params)` instead of delegating to
`get_comp_rep_parameter_indices`.

  **Design of `_ReducedSearchSpace`**

- Subclasses `SearchSpace` so it passes type checks and works with
existing factory signatures
- Builds real `SubspaceDiscrete`/`SubspaceContinuous` instances with
zero-row DataFrames so that `parameters`, `parameter_names`, and
`comp_rep_columns` work correctly
- Overrides `__getattribute__` with an allowlist — accessing
`transform`, `index`, etc. raises `AttributeError`

  **Allowed attributes**

  ```python
  _ALLOWED_ATTRIBUTES: ClassVar[frozenset[str]] = frozenset({
      "discrete",
      "continuous",
      "parameters",
      "parameter_names",
      "comp_rep_columns",
      "constraints",
      "type",
      "_task_parameter",
      "n_tasks",
      "_get_n_comp_rep_columns",
      "get_parameters_by_name",
      "_ALLOWED_ATTRIBUTES",
  })
  ```
**Outlook**

This class is a building block for kernel override dispatching in the GP
surrogate. The dispatching logic will:
  - Strip the task parameter from the search space
  - Call kernel factories on the reduced space to obtain a base kernel
  - Separately construct the task kernel
- Multiply both at the gpytorch level after resolving indices against
the full search space
- Ensure that factories called on a reduced search space return BayBE
Kernel objects (which use parameter names) rather than raw gpytorch
kernels (which bake in integer indices)
  - posterior_mean returns a mean factory that can be passed directly to a new GaussianProcessSurrogate via mean_or_factory
  - the new GP normalizes inputs before passing them to the mean module, so the factory undoes that normalization before querying the pretrained GP, which then applies its own normalization internally
  - Raises ModelNotTrainedError if the surrogate has not been fitted yet
  - Replace the posterior_mean property with get_posterior_mean() (with MeanFactoryProtocol signature)
  - method can be passed directly as mean_or_factory to a new GP 
  - Remove _PosteriorMeanFactory from mean.py
  - Move _PosteriorMean class and normalization into the method
- Add output normalization to posterior mean
- override train() on _PosteriorMean to prevent fit_gpytorch_mll from recursively switching nested submodules to training mode, which would change the learned Standardize parameters
- improve tests to use points from posterior mean
Both methods must use identical transform logic: a change in one
(e.g. replacing Normalize with a different input transform) must
automatically apply to the other, or get_posterior_mean silently
produces mismatched results.
- `anchors`: choose pretrained / new / combined data for the inner GP
- `mean_kernel_init`: freeze, warmstart or discard the inner hyperparameters
- Reject the no-op (new, discard) combo; warn on warmstart with new targets
- Split out _resolve_anchors and _build_inner_gp as module-level helpers
- Cover the new behaviour in tests/test_posterior_mean_function.py
@kalama-ai kalama-ai force-pushed the feature/posterior_mean_configurations branch from 30be3b1 to 1ae48b6 Compare June 24, 2026 09:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants